from .AC_TRANS import *
from .AC_FUN import AC_FUN
import math
import random
import time
# 颜色匹配
class ColorAdapter(AC_FUN):
    """Re-tint each input frame towards the colours of a reference frame.

    Every frame of ``image`` is run through the ``color_adapter`` helper
    together with a reference frame. The result is then blended back over the
    original frame with ``chop_image`` in ``'normal'`` mode at ``opacity``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE", ),
                "color_ref_image": ("IMAGE", ),
                "opacity": ("INT", {"default": 75, "min": 0, "max": 100, "step": 1}),
            },
            "optional": {
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = 'color_adapter'
    OUTPUT_NODE = True

    def color_adapter(self, image, color_ref_image, opacity):
        """Colour-match every frame of ``image`` against ``color_ref_image``.

        Args:
            image: IMAGE batch; iterated frame by frame along dim 0.
            color_ref_image: IMAGE batch of reference frames. Input frame ``i``
                is paired with reference frame ``i``; when the reference batch
                is shorter, its last frame is reused.
            opacity: blend strength (0-100) forwarded to ``chop_image``.

        Returns:
            A 1-tuple holding the output batch (concatenated along dim 0).
        """
        l_images = [torch.unsqueeze(l, 0) for l in image]
        r_images = [torch.unsqueeze(r, 0) for r in color_ref_image]
        ret_images = []
        for i, _image in enumerate(l_images):
            # Pair with the reference frame of the same index; fall back to the
            # last reference frame when the reference batch is shorter.
            _ref = r_images[i] if i < len(r_images) else r_images[-1]

            _canvas = ac_tensor2pil(_image).convert('RGB')
            # Bare name resolves to the module-level helper, not this method.
            ret_image = color_adapter(_canvas, ac_tensor2pil(_ref).convert('RGB'))
            ret_image = chop_image(_canvas, ret_image, blend_mode='normal', opacity=opacity)

            ret_images.append(ac_pil2tensor(ret_image))
        return (torch.cat(ret_images, dim=0),)

# 高亮材质
class ColorCorrectBrightness(AC_FUN):
    """Brightness / contrast / saturation correction for an image batch.

    The three factors feed the ``ImageEnhance`` enhancers ``Brightness``,
    ``Contrast`` and ``Color``, applied in that order. A factor of exactly 1
    skips its enhancer.
    """

    @classmethod
    def INPUT_TYPES(cls):
        def _factor_input():
            # A fresh options dict for each input.
            return ("FLOAT", {"default": 1, "min": 0.0, "max": 3, "step": 0.01})

        return {
            "required": {
                "image": ("IMAGE", ),
                "brightness": _factor_input(),
                "contrast": _factor_input(),
                "saturation": _factor_input(),
            },
            "optional": {
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = 'color_correct_brightness'

    def color_correct_brightness(self, image, brightness, contrast, saturation):
        """Apply the enabled enhancers to every frame and return the new batch."""
        pipeline = (
            (ImageEnhance.Brightness, brightness),
            (ImageEnhance.Contrast, contrast),
            (ImageEnhance.Color, saturation),
        )
        corrected = []
        for frame in image:
            picture = ac_tensor2pil(torch.unsqueeze(frame, 0)).convert('RGB')
            for enhancer_cls, factor in pipeline:
                if factor == 1:
                    continue  # neutral factor: leave the frame untouched
                picture = enhancer_cls(picture).enhance(factor=factor)
            corrected.append(ac_pil2tensor(picture))
        return (torch.cat(corrected, dim=0),)
# 风格映射
# Colour-map names offered by the ColorMap node. ColorMap forwards the
# *position* of the selected name to ``image_to_colormap``, so every entry must
# be unique (``list.index`` only ever finds the first duplicate) and the order
# must not change.
colormap_list = ['秋季', '骨感', '喷射', '冬天', '彩虹', '海洋',
                 '夏天', '春天', '冷色', 'HSV', '粉红', '热辣',
                 '岩浆', '地狱', '等离子', '翠绿', '文明',
                 '黄昏', '暮光', '加速', '深绿']

class ColorMap(AC_FUN):
    """Render each frame through a named colour map and blend it back.

    The selected ``color_map`` name is converted to its position in
    ``colormap_list`` and handed to ``image_to_colormap``. The mapped frame is
    then blended over the original with ``chop_image`` in ``'normal'`` mode at
    ``opacity``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "image": ("IMAGE", ),
            "color_map": (colormap_list,),
            "opacity": ("INT", {"default": 100, "min": 0, "max": 100, "step": 1}),
        }
        return {"required": required, "optional": {}}

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = 'map'

    def map(self, image, color_map, opacity):
        """Return a 1-tuple with the colour-mapped batch."""
        def _render(frame):
            source = ac_tensor2pil(torch.unsqueeze(frame, 0))
            mapped = image_to_colormap(source, colormap_list.index(color_map))
            return ac_pil2tensor(chop_image(source, mapped, 'normal', opacity))

        return (torch.cat([_render(frame) for frame in image], dim=0),)

# 颜色拾取器
class ColorPicker(AC_FUN):
    """Output a picked colour as a string.

    In ``'HEX'`` mode the picker value is returned unchanged. In ``'DEC'`` mode
    it is first passed through ``Hex_to_RGB``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "color": ("COLOR", {"default": "#FFFFFF"},),
                "mode": (['HEX', 'DEC'],),
            },
            "optional": {
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("value",)
    FUNCTION = 'picker'

    def picker(self, color, mode):
        """Return ``(value,)``: the raw colour, or its ``Hex_to_RGB`` conversion."""
        if mode != 'DEC':
            return (color,)
        return (Hex_to_RGB(color),)

# 高斯模糊
class AC_GausBlur(AC_FUN):
    """Blur every frame of a batch with the ``gaussian_blur`` helper."""

    def __init__(self):
        # Intentionally empty (kept so any base-class __init__ is not run).
        pass

    @classmethod
    def INPUT_TYPES(cls):
        blur_spec = ("INT", {"default": 20, "min": 1, "max": 999, "step": 1})
        return {
            "required": {"image": ("IMAGE", ), "blur": blur_spec},
            "optional": {},
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = 'gaus_blur'

    OUTPUT_NODE = True

    def gaus_blur(self, image, blur):
        """Return ``(batch,)`` where each RGB frame went through ``gaussian_blur(frame, blur)``."""
        blurred = [
            ac_pil2tensor(gaussian_blur(ac_tensor2pil(torch.unsqueeze(frame, 0)).convert('RGB'), blur))
            for frame in image
        ]
        return (torch.cat(blurred, dim=0),)


# 图片混合
# Blend-mode names offered by the ImageBlend and AC_ColorOverlay nodes. The
# selected entry is forwarded unchanged as the ``blend_mode`` argument of
# ``chop_image``.
chop_mode = [
    '法线',
    '叠加',
    '屏幕',
    '添加',
    '减去',
    '差分',
    '深色',
    '亮度',
    '颜色高亮',
    '颜色减淡',
    '线性高亮',
    '线性减淡',
    '覆盖',
    '暖光',
    '硬光',
    '环境光',
    '固定光',
    '线性光',
    '强制混合',
]
class ImageBlend(AC_FUN):
    """Blend ``layer_image`` over ``background_image`` inside a mask.

    Each background frame is blended with the matching layer frame through
    ``chop_image`` (``blend_mode`` / ``opacity``). The blended result is then
    pasted back onto the background through a mask. Shorter batches repeat
    their last frame.
    """

    def __init__(self,):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        # "optional" must sit beside "required" (not inside it) for layer_mask
        # to be exposed as a genuinely optional input.
        return {
            "required": {
                "background_image": ("IMAGE", ),
                "layer_image": ("IMAGE",),
                "invert_mask": ("BOOLEAN", {"default": True}),
                "blend_mode": (chop_mode,),
                "opacity": ("INT", {"default": 100, "min": 0, "max": 100, "step": 1}),
            },
            "optional": {
                "layer_mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = 'image_blend'

    OUTPUT_NODE = True

    def image_blend(self, background_image, layer_image,
                    invert_mask, blend_mode, opacity,
                    layer_mask=None
                    ):
        """Composite the layer batch onto the background batch.

        Mask source for each layer frame:
          * ``layer_mask`` when given (inverted first if ``invert_mask``);
          * otherwise the frame's own alpha channel if it is RGBA;
          * otherwise a fully white mask.
        A mask whose size differs from the layer is replaced by a white one.

        Returns:
            A 1-tuple with the composited batch.
        """
        b_images = [torch.unsqueeze(b, 0) for b in background_image]
        l_images = []
        l_masks = []
        for l in layer_image:
            l_images.append(torch.unsqueeze(l, 0))
            m = ac_tensor2pil(l)
            # Exactly one mask per layer frame: its alpha, or plain white.
            if m.mode == 'RGBA':
                l_masks.append(m.split()[-1])
            else:
                l_masks.append(Image.new('L', m.size, 'white'))

        # An explicit mask replaces the alpha-derived masks; only touch it when
        # it was actually supplied.
        if layer_mask is not None:
            if layer_mask.dim() == 2:
                layer_mask = torch.unsqueeze(layer_mask, 0)
            l_masks = []
            for m in layer_mask:
                if invert_mask:
                    m = 1 - m
                l_masks.append(ac_tensor2pil(torch.unsqueeze(m, 0)).convert('L'))

        ret_images = []
        max_batch = max(len(b_images), len(l_images), len(l_masks))
        for i in range(max_batch):
            _background = b_images[i] if i < len(b_images) else b_images[-1]
            _layer_tensor = l_images[i] if i < len(l_images) else l_images[-1]
            _mask = l_masks[i] if i < len(l_masks) else l_masks[-1]
            _canvas = ac_tensor2pil(_background).convert('RGB')
            _layer = ac_tensor2pil(_layer_tensor).convert('RGB')
            if _mask.size != _layer.size:
                _mask = Image.new('L', _layer.size, 'white')

            _comp = chop_image(_canvas, _layer, blend_mode, opacity)
            _canvas.paste(_comp, mask=_mask)
            ret_images.append(ac_pil2tensor(_canvas))
        return (torch.cat(ret_images, dim=0),)

# 图像通道合并
class ImageChannelMerge(AC_FUN):
    """Assemble one image from three (or four) single-channel image batches.

    Frame ``i`` of each channel batch is converted to PIL and handed to the
    ``image_channel_merge`` helper together with ``mode``. Shorter batches
    repeat their last frame. Without ``channel_4``, a white plane the size of
    the first ``channel_1`` frame is used as the fourth channel.
    """

    @classmethod
    def INPUT_TYPES(cls):
        channel_mode = ['RGBA', 'YCbCr', 'LAB', 'HSV']
        return {
            "required": {
                "channel_1": ("IMAGE", ),
                "channel_2": ("IMAGE",),
                "channel_3": ("IMAGE",),
                "mode": (channel_mode,),  # channel mode of the merged result
            },
            "optional": {
                "channel_4": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = 'image_channel_merge'
    OUTPUT_NODE = True

    def image_channel_merge(self, channel_1, channel_2, channel_3, mode, channel_4=None):
        """Return ``(batch,)`` of merged images."""
        c1_images = [torch.unsqueeze(c, 0) for c in channel_1]
        c2_images = [torch.unsqueeze(c, 0) for c in channel_2]
        c3_images = [torch.unsqueeze(c, 0) for c in channel_3]
        if channel_4 is not None:
            c4_images = [torch.unsqueeze(c, 0) for c in channel_4]
        else:
            # Default fourth channel: a white plane matching channel_1's size.
            # It is stored as a tensor like every other entry because all
            # entries go through ac_tensor2pil below (a raw PIL image would not).
            width, height = ac_tensor2pil(c1_images[0]).size
            c4_images = [ac_pil2tensor(Image.new('L', size=(width, height), color='white'))]

        def _at(seq, idx):
            # Shorter batches repeat their final frame.
            return seq[idx] if idx < len(seq) else seq[-1]

        ret_images = []
        max_batch = max(len(c1_images), len(c2_images), len(c3_images), len(c4_images))
        for i in range(max_batch):
            planes = (
                ac_tensor2pil(_at(c1_images, i)),
                ac_tensor2pil(_at(c2_images, i)),
                ac_tensor2pil(_at(c3_images, i)),
                ac_tensor2pil(_at(c4_images, i)),
            )
            # Bare name resolves to the module-level helper, not this method.
            ret_image = image_channel_merge(planes, mode)
            ret_images.append(ac_pil2tensor(ret_image))
        return (torch.cat(ret_images, dim=0),)

# 图像通道切割
class ImageChannelSplit(AC_FUN):
    """Split every frame into four channel images via ``image_channel_split``.

    Frames are converted to RGBA first. The four results of the helper become
    the four output batches, in order.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE", ),
                "mode": (['RGBA', 'YCbCr', 'LAB', 'HSV'],),
            },
            "optional": {
            }
        }

    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "IMAGE",)
    RETURN_NAMES = ("channel_1", "channel_2", "channel_3", "channel_4",)
    FUNCTION = 'image_channel_split'
    OUTPUT_NODE = True

    def image_channel_split(self, image, mode):
        """Return a 4-tuple of batches, one per channel position."""
        buckets = ([], [], [], [])
        for frame in image:
            rgba = ac_tensor2pil(torch.unsqueeze(frame, 0)).convert('RGBA')
            # Bare name resolves to the module-level helper; it must yield 4 parts.
            first, second, third, fourth = image_channel_split(rgba, mode)
            for bucket, channel in zip(buckets, (first, second, third, fourth)):
                bucket.append(ac_pil2tensor(channel))
        return tuple(torch.cat(bucket, dim=0) for bucket in buckets)

# 图像alpha通道混合
class ImageCombineAlpha(AC_FUN):
    """Attach a mask as the alpha channel of an RGB image batch.

    Each RGB frame is split into R, G and B and merged with the matching mask
    frame (as an 'L' plane) into an RGBA image. Shorter batches repeat their
    last frame.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "RGB_image": ("IMAGE", ),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("RGBA_image",)
    FUNCTION = 'image_combine_alpha'

    def image_combine_alpha(self, RGB_image, mask):
        """Return ``(batch,)`` of RGBA images."""
        frames = [torch.unsqueeze(frame, 0) for frame in RGB_image]
        if mask.dim() == 2:
            # A single 2-D mask becomes a batch of one.
            mask = torch.unsqueeze(mask, 0)
        alphas = [torch.unsqueeze(alpha, 0) for alpha in mask]

        def _pick(seq, idx):
            return seq[idx] if idx < len(seq) else seq[-1]

        merged = []
        for idx in range(max(len(frames), len(alphas))):
            red, green, blue, _unused = image_channel_split(
                ac_tensor2pil(_pick(frames, idx)).convert('RGB'), 'RGB')
            alpha_plane = ac_tensor2pil(_pick(alphas, idx)).convert('L')
            merged.append(ac_pil2tensor(
                image_channel_merge((red, green, blue, alpha_plane), 'RGBA')))
        return (torch.cat(merged, dim=0),)

# 图片拉伸
class AC_ImageScaleRestore(AC_FUN):
    """Resize an image batch (and its mask) by a factor, to a longest side,
    or back to a previously recorded size.

    The target size is chosen in this order of precedence:
      1. ``original_size`` when connected;
      2. ``longest_side`` on the longer edge, when ``scale_by_longest_side``;
      3. otherwise ``scale`` times the first frame's size.
    Both edges are clamped to at least 4 px.
    """

    @classmethod
    def INPUT_TYPES(cls):
        method_mode = ['lanczos', 'bicubic', 'hamming', 'bilinear', 'box', 'nearest']
        return {
            "required": {
                "image": ("IMAGE", ),
                "scale": ("FLOAT", {"default": 1, "min": 0.01, "max": 100, "step": 0.01}),
                "method": (method_mode,),
                "scale_by_longest_side": ("BOOLEAN", {"default": False}),
                "longest_side": ("INT", {"default": 1024, "min": 4, "max": 999999, "step": 1}),
            },
            "optional": {
                "mask": ("MASK",),
                "original_size": ("BOX",),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", "BOX",)
    RETURN_NAMES = ("image", "mask", "original_size")
    FUNCTION = 'image_scale_restore'

    def image_scale_restore(self, image, scale, method,
                            scale_by_longest_side, longest_side,
                            mask = None,  original_size = None
                            ):
        """Resize the batch.

        Mask source: the explicit ``mask`` if given, else the alpha channel of
        RGBA frames, else a white mask of the output size.

        Returns:
            ``(images, masks, [orig_width, orig_height])``, where the last item
            is the first frame's size before resizing (feed it back through
            ``original_size`` to undo the scale).
        """
        l_images = []
        l_masks = []
        for l in image:
            l_images.append(torch.unsqueeze(l, 0))
            m = ac_tensor2pil(l)
            if m.mode == 'RGBA':
                l_masks.append(m.split()[-1])

        if mask is not None:
            if mask.dim() == 2:
                mask = torch.unsqueeze(mask, 0)
            l_masks = [ac_tensor2pil(torch.unsqueeze(m, 0)).convert('L') for m in mask]

        orig_width, orig_height = ac_tensor2pil(l_images[0]).size
        if original_size is not None:
            # int() so a BOX carrying float values still yields a valid resize size.
            target_width = int(original_size[0])
            target_height = int(original_size[1])
        elif scale_by_longest_side:
            if orig_width > orig_height:
                target_width = longest_side
                target_height = int(target_width * orig_height / orig_width)
            else:
                target_height = longest_side
                target_width = int(target_height * orig_width / orig_height)
        else:
            target_width = int(orig_width * scale)
            target_height = int(orig_height * scale)
        target_size = (max(target_width, 4), max(target_height, 4))

        resize_sampler = {
            "bicubic": Image.BICUBIC,
            "hamming": Image.HAMMING,
            "bilinear": Image.BILINEAR,
            "box": Image.BOX,
            "nearest": Image.NEAREST,
        }.get(method, Image.LANCZOS)

        ret_images = []
        ret_masks = []
        for i in range(max(len(l_images), len(l_masks))):
            _image = l_images[i] if i < len(l_images) else l_images[-1]
            ret_image = ac_tensor2pil(_image).convert('RGB').resize(target_size, resize_sampler)
            # Use any available mask (explicit or RGBA alpha); previously the
            # alpha masks were collected but never used.
            if l_masks:
                _mask = l_masks[i] if i < len(l_masks) else l_masks[-1]
                ret_mask = _mask.resize(target_size, resize_sampler)
            else:
                ret_mask = Image.new('L', size=ret_image.size, color='white')

            ret_images.append(ac_pil2tensor(ret_image))
            ret_masks.append(image2mask(ret_mask))
        return (torch.cat(ret_images, dim=0), torch.cat(ret_masks, dim=0), [orig_width, orig_height],)

# 动态模糊
class AC_MotionBlur(AC_FUN):
    """Run every frame through the ``motion_blur`` helper.

    ``angle`` and ``blur`` are passed straight to the helper. Presumably they
    are the blur direction and strength; confirm against ``motion_blur``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE", ),
                "angle": ("INT", {"default": 0, "min": -90, "max": 90, "step": 1}),  # angle
                "blur": ("INT", {"default": 20, "min": 1, "max": 999, "step": 1}),  # blur amount
            },
            "optional": {
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = 'motion_blur'

    def motion_blur(self, image, angle, blur):
        """Return ``(batch,)`` of blurred RGB frames."""
        def _blur_frame(frame):
            rgb = ac_tensor2pil(torch.unsqueeze(frame, 0)).convert('RGB')
            # Bare name resolves to the module-level helper, not this method.
            return ac_pil2tensor(motion_blur(rgb, angle, blur))

        return (torch.cat([_blur_frame(frame) for frame in image], dim=0),)

# 像素wiggle
class AC_Picture_Wiggle(AC_FUN):
    """Offset colour channels against each other (channel-shift effect).

    ``(x, y)`` is ``distance`` along direction ``angle`` (degrees). The channel
    named by the first letter of ``mode`` is shifted by ``(-x, -y)``, the one
    named by the last letter by ``(x, y)``, and the remaining channel is left
    in place.
    """

    @classmethod
    def INPUT_TYPES(cls):
        channel_mode = ['RGB', 'RBG', 'BGR', 'BRG', 'GBR', 'GRB']
        return {
            "required": {
                "image": ("IMAGE", ),
                "distance": ("INT", {"default": 20, "min": 1, "max": 999, "step": 1}),  # shift length
                "angle": ("FLOAT", {"default": 40, "min": -360, "max": 360, "step": 0.1}),  # direction, degrees
                "mode": (channel_mode,),  # first letter moves back, last letter moves forward
            },
            "optional": {
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = 'channel_shake'

    def channel_shake(self, image, distance, angle, mode, ):
        """Return ``(batch,)`` of RGB frames with two channels shifted."""
        # ``angle`` is entered in degrees (UI range -360..360), but math.cos and
        # math.sin take radians. The offset is identical for every frame.
        theta = math.radians(angle)
        x = int(math.cos(theta) * distance)
        y = int(math.sin(theta) * distance)

        ret_images = []
        for i in image:
            _canvas = ac_tensor2pil(torch.unsqueeze(i, 0)).convert('RGB')
            channels = dict(zip('RGB', _canvas.split()))
            # Slices (not indexing) so an unexpected empty mode shifts nothing.
            for name, dx, dy in ((mode[:1], -x, -y), (mode[-1:], x, y)):
                if name in channels:
                    channels[name] = shift_image(channels[name].convert('RGB'), dx, dy).convert('L')
            ret_image = Image.merge('RGB', [channels['R'], channels['G'], channels['B']])
            ret_images.append(ac_pil2tensor(ret_image))

        return (torch.cat(ret_images, dim=0),)

# 获取颜色集合
class AC_GetColorTone(AC_FUN):
    """Report the main or average colour of the first frame of a batch.

    The frame is blurred with ``gaussian_blur`` (strength ``(w + h) // 200``)
    and then analysed with ``get_image_color_tone`` (``'main_color'``) or
    ``get_image_color_average`` (otherwise).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE", ),
                "mode": (['main_color', 'average'],),  # dominant vs. average colour
            },
        }

    RETURN_TYPES = ("STRING", "LIST")
    RETURN_NAMES = ("RGB HEX", "HSV color")
    FUNCTION = 'get_color_tone'

    def get_color_tone(self, image, mode,):
        """Return ``(hex_color, hsv_color)``; the HSV value comes from ``RGB_to_HSV(Hex_to_RGB(hex))``."""
        if image.shape[0] > 0:
            image = image[:1]  # analyse only the first frame
        picture = ac_tensor2pil(image).convert('RGB')
        picture = gaussian_blur(picture, (picture.width + picture.height) // 200)
        analyse = get_image_color_tone if mode == 'main_color' else get_image_color_average
        hex_color = analyse(picture)
        return (hex_color, RGB_to_HSV(Hex_to_RGB(hex_color)))

# 渐变色
class AC_ColorOverlay(AC_FUN):
    """Tint a layer with a flat colour and paste it onto a background.

    Each layer frame is blended with a solid ``color`` fill through
    ``chop_image`` (``blend_mode`` / ``opacity``). The result is pasted onto
    the background through a mask. The mask comes from ``layer_mask``
    (optionally inverted) or, failing that, from the alpha of RGBA layer
    frames. With no mask at all, the background batch is returned untouched.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "background_image": ("IMAGE", ),
            "layer_image": ("IMAGE",),
            "invert_mask": ("BOOLEAN", {"default": True}),
            "blend_mode": (chop_mode,),
            "opacity": ("INT", {"default": 100, "min": 0, "max": 100, "step": 1}),
            "color": ("STRING", {"default": "#FFBF30"}),
        }
        optional = {"layer_mask": ("MASK",)}
        return {"required": required, "optional": optional}

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = 'color_overlay'

    def color_overlay(self, background_image, layer_image,
                      invert_mask, blend_mode, opacity, color,
                      layer_mask=None
                      ):
        """Return ``(batch,)`` of composited frames (or the untouched background)."""
        def _nth(items, idx):
            # Shorter batches repeat their final element.
            return items[idx] if idx < len(items) else items[-1]

        backgrounds = [torch.unsqueeze(frame, 0) for frame in background_image]
        layers = []
        masks = []
        for frame in layer_image:
            layers.append(torch.unsqueeze(frame, 0))
            frame_pil = ac_tensor2pil(frame)
            if frame_pil.mode == 'RGBA':
                masks.append(frame_pil.split()[-1])

        if layer_mask is not None:
            if layer_mask.dim() == 2:
                layer_mask = torch.unsqueeze(layer_mask, 0)
            masks = [
                ac_tensor2pil(torch.unsqueeze(1 - m if invert_mask else m, 0)).convert('L')
                for m in layer_mask
            ]

        if not masks:
            return (background_image,)

        fill = Image.new("RGB", ac_tensor2pil(layers[0]).size, color=color)
        composed = []
        for idx in range(max(len(backgrounds), len(layers), len(masks))):
            mask = _nth(masks, idx)
            canvas = ac_tensor2pil(_nth(backgrounds, idx)).convert('RGB')
            layer = ac_tensor2pil(_nth(layers, idx)).convert('RGB')
            if mask.size != layer.size:
                mask = Image.new('L', layer.size, 'white')
            # Blend the layer with the solid fill, then paste through the mask.
            canvas.paste(chop_image(layer, fill, blend_mode, opacity), mask=mask)
            composed.append(ac_pil2tensor(canvas))

        return (torch.cat(composed, dim=0),)

# 图像曝光
# Blend mode AC_LightLeak uses to composite the light texture over each frame.
blend_mode = 'screen'

class AC_LightLeak(AC_FUN):
    """Overlay one of the bundled light-leak textures on every frame.

    The texture comes from ``load_light_leak_images()``. It is rotated for
    portrait frames, mirrored towards ``corner``, optionally hue/saturation
    shifted, crop-fitted to the frame, and blended with ``blend_mode`` at
    ``opacity``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        # 'random' plus textures '1'..'32'.
        light_list = ['random'] + [str(n) for n in range(1, 33)]
        corner_list = ['left_top', 'right_top', 'left_bottom', 'right_bottom']
        return {
            "required": {
                "image": ("IMAGE", ),
                "light": (light_list,),
                "corner": (corner_list,),
                "hue": ("INT", {"default": 0, "min": -255, "max": 255, "step": 1}),
                "saturation": ("INT", {"default": 0, "min": -255, "max": 255, "step": 1}),
                "opacity": ("INT", {"default": 100, "min": 0, "max": 100, "step": 1})
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = 'light_leak'

    def light_leak(self, image, light, corner, hue, saturation, opacity):
        """Return ``(batch,)`` with the chosen light leak applied to every frame."""
        ret_images = []
        light_leak_images = load_light_leak_images()
        if light == 'random':
            # Private time-seeded generator: picking a texture must not reseed
            # the process-wide ``random`` state that other nodes may rely on.
            light_index = random.Random(time.time()).randint(0, 31)
        else:
            light_index = int(light) - 1

        for i in image:
            _canvas = ac_tensor2pil(torch.unsqueeze(i, 0)).convert('RGB')
            _light = light_leak_images[light_index]
            if _canvas.width < _canvas.height:
                # Portrait frame: turn the texture to match.
                _light = _light.transpose(Image.ROTATE_90).transpose(Image.FLIP_TOP_BOTTOM)
            if corner == 'right_top':
                _light = _light.transpose(Image.FLIP_LEFT_RIGHT)
            elif corner == 'left_bottom':
                _light = _light.transpose(Image.FLIP_TOP_BOTTOM)
            elif corner == 'right_bottom':
                _light = _light.transpose(Image.ROTATE_180)
            if hue != 0 or saturation != 0:
                _h, _s, _v = _light.convert('HSV').split()
                if hue != 0:
                    _h = image_hue_offset(_h, hue)
                if saturation != 0:
                    _s = image_gray_offset(_s, saturation)
                _light = image_channel_merge((_h, _s, _v), 'HSV')
            _light = fit_resize_image(_light, _canvas.width, _canvas.height, fit='crop', resize_sampler=Image.BILINEAR)
            ret_image = chop_image(_canvas, _light, blend_mode=blend_mode, opacity=opacity)
            ret_images.append(ac_pil2tensor(ret_image))

        return (torch.cat(ret_images, dim=0),)


# # 图片旋转
# class AC_Rotate(AC_FUN):
#     @classmethod
#     def INPUT_TYPES(self):
#         return {
#             "required": {
#                 "image": ("IMAGE", ),
#                 "angle": ("INT", {"default": 0, "min": -360, "max": 360, "step": 1}),
#             },
#             "optional": {
#                 "background_color": ("STRING", {"default": "#000000"}),
#             }}
#     RETURN_TYPES = ("IMAGE",)   
#     RETURN_NAMES = ("image",)
#     FUNCTION = "ac_rotate"